using LinearAlgebra
using DataFrames
using CSV
using DifferentialEquations
using SparseArrays
using KrylovKit
using MKL

include("params-twobody.jl")

# Integer step indices 0, 1, …, floor(tmax/tstep + 1/2); multiplied by `tstep` they give
# the time column of the output table (index 0 is the initial point).
evaluation_tsteps = collect(0:floor(Int64,tmax/tstep+1/2))
# Upper bound on (ho_level_left + ho_level_right - 2) used when the initial state is built.
max_ho_level = 2num_lvls;
trap_depth_left = trap_depth # MHz; nominal in save file
trap_depth_right = trap_depth # MHz; nominal in save file
# Factors that multiply the potential grid values read from file.
# NOTE(review): the 1/4 presumably reflects the depth the stored potentials were computed for — confirm.
trap_depth_factor_left = trap_depth_left/4
trap_depth_factor_right = trap_depth_right/4
# Directory searched for the potential .h5 files (currently the working directory).
potential_dir = pwd(); #"/path/to/dir"

include("$functions_path/nacs-tweezer-dvr.jl")
include("$functions_path/nacs-compute-interactions.jl")
include("$functions_path/master-equation-diffeqs.jl")
"""
    coherent_state_number_probability(alpha, n) -> Float64

Return the Poisson weight `exp(-|alpha|^2) * |alpha|^(2n) / n!`, the probability of
finding `n` quanta in a coherent state with amplitude `alpha`.

The weight is accumulated as a running product of the ratios `|alpha|^2 / k` for
`k = 1, …, n`. Neither `n!` nor `|alpha|^(2n)` is formed explicitly, so the function
also works for `n > 20`, where `factorial(::Int64)` throws an `OverflowError`.

# Throws
- `DomainError` if `n` is negative.
"""
function coherent_state_number_probability(alpha::T, n::Int64)::Float64 where T <: Number
    n >= 0 || throw(DomainError(n, "`n` must be non-negative"))

    mean_number = abs2(alpha)
    probability = exp(-mean_number)
    for k in 1:n
        # P(k) = P(k-1) * |alpha|^2 / k
        probability *= mean_number / k
    end

    return probability
end
"""
    thermal_state_number_probability(beta, n) -> Float64

Return the thermal occupation `(1 - exp(-beta)) * exp(-beta * n)` of level `n`.

`beta` is dimensionless: it is given in units of hbar omega_z, so `beta * n` is the
Boltzmann exponent of the `n`-th level.
"""
function thermal_state_number_probability(beta::Float64, n::Int64)::Float64
    normalization = 1 - exp(-beta)
    boltzmann_weight = exp(-beta*n)
    return normalization * boltzmann_weight
end

"""
    evolve_rho!(rhot_arr, rho0_arr, tt, params_arr, tol, krylov_maxiter)

Propagate every density matrix in `rho0_arr` over the time `tt` and store the result in
the matching entry of `rhot_arr` (overwritten in place).

For each entry, KrylovKit's `exponentiate` is applied to the linear map
`rho -> mastereq_dt(rho, params)`, i.e. the result is `exp(tt * L) * rho0` with `L` the
generator implemented by `mastereq_dt`.

# Arguments
- `rhot_arr`: output matrices; `rhot_arr[ii]` must have the size of `rho0_arr[ii]`.
- `rho0_arr`: initial density matrices (not modified).
- `tt`: propagation time.
- `params_arr`: one parameter tuple per density matrix, forwarded to `mastereq_dt`.
- `tol`: tolerance passed to `exponentiate`.
- `krylov_maxiter`: `maxiter` passed to `exponentiate`. Previously this argument was
  ignored and a fixed value of 1000 was used.

# Returns
A vector with the convergence-info object returned by `exponentiate` for each entry.
"""
function evolve_rho!(
    rhot_arr::Vector{Matrix{ComplexF64}},
    rho0_arr::Vector{Matrix{ComplexF64}},
    tt::Float64,
    params_arr::Vector{Tuple{SparseMatrixCSC{ComplexF64, Int64}, SparseMatrixCSC{ComplexF64, Int64}, Vector{SparseMatrixCSC{Float64, Int64}}, Vector{SparseMatrixCSC{Float64, Int64}}, Matrix{ComplexF64}, Matrix{ComplexF64}}},
    tol::Float64,
    krylov_maxiter::Int64,
)
    info_arr = Any[]
    for (ii, rho0, params) in Iterators.zip(eachindex(rhot_arr), rho0_arr, params_arr)
        # Honour the caller-supplied iteration limit instead of a hard-coded one.
        rhot_tmp, info = exponentiate(tt, rho0, tol=tol, maxiter=krylov_maxiter) do rho
            mastereq_dt(rho, params)
        end

        push!(info_arr, info)
        # Copy into the preallocated output so the caller's matrix is reused.
        rhot_arr[ii] .= rhot_tmp
    end

    return info_arr
end

"""
    trace_out_motion(rho_arr, energy_submanifold_arr)

Return the 4×4 matrix obtained by summing, over every submanifold and every motional
index `k`, the sub-block of `rho` at rows/columns `k, d + k, 2d + k, 3d + k`, where
`d = length(energy_submanifold)` is the motional dimension of that submanifold.

`rho_arr` and `energy_submanifold_arr` are traversed in lockstep.
"""
function trace_out_motion(rho_arr, energy_submanifold_arr)
    paired = Iterators.zip(rho_arr, energy_submanifold_arr)

    reduced = mapreduce(+, paired) do pair
        rho, submanifold = pair
        nmotion = length(submanifold)

        # One 4×4 block per motional index; the four picks are the four spin sectors.
        mapreduce(+, 1:nmotion) do k
            picks = [k + sector * nmotion for sector in 0:3]
            rho[picks, picks]
        end
    end

    return reduced
end


"""
    state_fidelity(rho_arr, energy_submanifold_arr, psi)

Return `psi' * rho_red * psi`, where `rho_red` is the matrix produced by
`trace_out_motion(rho_arr, energy_submanifold_arr)`.
"""
function state_fidelity(rho_arr, energy_submanifold_arr, psi)
    spin_rho = trace_out_motion(rho_arr, energy_submanifold_arr)
    bra = adjoint(psi)
    return bra * spin_rho * psi
end


# Read the gridded potentials from "N0<aberration_type>.h5" and "N1M<aberration_type>.h5"
# inside `potential_dir`.
potential_array_N0 = potential_array_from_h5(string(potential_dir, "/N0$aberration_type.h5"))
potential_array_N1M = potential_array_from_h5(string(potential_dir, "/N1M$aberration_type.h5"))

# Grid spacing is taken from the N0 potential and reused for all four Hamiltonians below.
step_size_arr = compute_step_size_arr(potential_array_N0)
# Grid values scaled by the per-tweezer depth factor (left/right × N0/N1M).
potential_N0_left = potential_gridvalues(potential_array_N0) * trap_depth_factor_left
potential_N0_right = potential_gridvalues(potential_array_N0) * trap_depth_factor_right
potential_N1M_left = potential_gridvalues(potential_array_N1M) * trap_depth_factor_left
potential_N1M_right = potential_gridvalues(potential_array_N1M) * trap_depth_factor_right

# Single-particle DVR Hamiltonians divided by 2π·ħ.
# NOTE(review): presumably this converts energies to frequency units — confirm against dvr_hamiltonian.
hamiltonian_N0_left = dvr_hamiltonian(m_NaCs, step_size_arr, potential_N0_left) / (2pi*hbar);
hamiltonian_N1M_left = dvr_hamiltonian(m_NaCs, step_size_arr, potential_N1M_left) / (2pi*hbar);
hamiltonian_N0_right = dvr_hamiltonian(m_NaCs, step_size_arr, potential_N0_right) / (2pi*hbar);
hamiltonian_N1M_right = dvr_hamiltonian(m_NaCs, step_size_arr, potential_N1M_right) / (2pi*hbar);
# Request `num_lvls` eigenpairs with the smallest real eigenvalues (`:SR`) from each
# Hamiltonian. Only the first solve is timed and has its convergence info printed; every
# solve is followed by a check that at least `num_lvls` eigenpairs converged.
@time energies_N0_left, states_N0_left, krylov_info = eigsolve(hamiltonian_N0_left, num_lvls, :SR, Float64, tol=krylov_tol, krylovdim=2num_lvls, maxiter=2000); # eigen(Matrix(hamiltonian_N0_left));
@show krylov_info
flush(stdout)
@assert krylov_info.converged >= num_lvls
energies_N1M_left, states_N1M_left, krylov_info = eigsolve(hamiltonian_N1M_left, num_lvls, :SR, Float64, tol=krylov_tol, krylovdim=2num_lvls, maxiter=2000); # eigen(Matrix(hamiltonian_N1M_left));
@assert krylov_info.converged >= num_lvls
energies_N0_right, states_N0_right, krylov_info = eigsolve(hamiltonian_N0_right, num_lvls, :SR, Float64, tol=krylov_tol, krylovdim=2num_lvls, maxiter=2000); # eigen(Matrix(hamiltonian_N0_right));
@assert krylov_info.converged >= num_lvls
energies_N1M_right, states_N1M_right, krylov_info = eigsolve(hamiltonian_N1M_right, num_lvls, :SR, Float64, tol=krylov_tol, krylovdim=2num_lvls, maxiter=2000); # eigen(Matrix(hamiltonian_N1M_right));
@assert krylov_info.converged >= num_lvls
# Keep exactly the first `num_lvls` eigenvectors (stacked as matrix columns) and the
# matching eigenvalues; the solver output is truncated in case it holds more entries.
states_N0_left = hcat(states_N0_left[1:num_lvls]...);
states_N1M_left = hcat(states_N1M_left[1:num_lvls]...);
states_N0_right = hcat(states_N0_right[1:num_lvls]...);
states_N1M_right = hcat(states_N1M_right[1:num_lvls]...);
energies_N0_left   = energies_N0_left[1:num_lvls];
energies_N1M_left  = energies_N1M_left[1:num_lvls];
energies_N0_right  = energies_N0_right[1:num_lvls];
energies_N1M_right = energies_N1M_right[1:num_lvls];
# Fix the sign of each eigenvector: column k is multiplied by the sign of its overlap with
# column k of a reference set (`Diagonal` keeps only the diagonal of the sign matrix).
# References: N1M-left → N0-left, N0-right → N0-left, N1M-right → (already fixed) N0-right.
states_N1M_left *= Diagonal(sign.(states_N1M_left'states_N0_left));
states_N0_right *= Diagonal(sign.(states_N0_right'states_N0_left));
states_N1M_right *= Diagonal(sign.(states_N1M_right'states_N0_right));
# Lookup tables returned by create_idx_pos_dict for the N0 grid.
idx_to_position, position_to_idx = create_idx_pos_dict(potential_array_N0);

# With equal trap depths the left and right eigenvalues and (sign-fixed) eigenvectors must
# agree; this guards against inconsistent eigensolves before the expensive steps below.
if isapprox(trap_depth_left, trap_depth_right)
    @assert isapprox(energies_N0_left, energies_N0_right)
    @assert isapprox(energies_N1M_left, energies_N1M_right)
    @assert isapprox(states_N0_left, states_N0_right)
    @assert isapprox(states_N1M_left, states_N1M_right)
    println("Confirmed identical states...")
    flush(stdout)
end

# `delta_pulse` plus half of: (sum of the N0 lowest-state energy expectation values of both
# tweezers) minus (sum of the N1M lowest-state energy expectation values of both tweezers).
energy_offset = delta_pulse + 1/2 * (
    states_N0_left[:,1]' * hamiltonian_N0_left * states_N0_left[:,1] +
    states_N0_right[:,1]' * hamiltonian_N0_right * states_N0_right[:,1] -
    states_N1M_left[:,1]' * hamiltonian_N1M_left * states_N1M_left[:,1] -
    states_N1M_right[:,1]' * hamiltonian_N1M_right * states_N1M_right[:,1]
)
# First kron idx: ↑, second kron idx: ↓

# The argument is the matrix of pairwise sums, entry [i, j] = energies_N0_right[i] + energies_N0_left[j].
# NOTE(review): energy_subspaces presumably groups level pairs whose summed energies agree
# within `energy_resolution` — confirm in the included source.
energy_submanifold_arr = energy_subspaces(energies_N0_left' .+ energies_N0_right, energy_resolution)
@show energy_submanifold_arr

# in inverse milliseconds (order of dynamics)
# Three calls to compute_pulse_phase_twobody_hamiltonian follow. They share every argument
# except the 11th positional one (0. here, Omega_pulse below), the 13th (pi/2, pi/2, 0.) and
# `return_dicts` (only set in this first call, which therefore also returns
# `idx_to_state_arr`).
# NOTE(review): the 11th and 13th arguments look like the drive Rabi frequency and drive
# phase respectively — confirm against nacs-compute-interactions.jl.
println("Computing H_dd...")
@time H_dd_arr, idx_to_state_arr = compute_pulse_phase_twobody_hamiltonian(
    states_N0_left,
    states_N1M_left,
    energies_N0_left,
    energies_N1M_left,
    states_N0_right,
    states_N1M_right,
    energies_N0_right,
    energies_N1M_right,
    num_lvls,
    energy_submanifold_arr,
    0.,
    energy_offset,
    pi/2,
    Float64[tweezer_distance,0,0],
    Float64[1,0,0],
    idx_to_position,
    return_dicts=true
);
flush(stdout)
#@time U_dd_2 = exp(-im * H_dd * tstep);
# Invert each index → state-label list into a label → index dictionary (one per submanifold).
state_to_idx_arr = [Dict(idx_to_state .=> eachindex(idx_to_state)) for idx_to_state in idx_to_state_arr];
println("Computing H_dd_y...")
# Same Hamiltonian with the 11th argument set to Omega_pulse and the 13th to pi/2.
@time H_dd_y_arr = compute_pulse_phase_twobody_hamiltonian(
    states_N0_left,
    states_N1M_left,
    energies_N0_left,
    energies_N1M_left,
    states_N0_right,
    states_N1M_right,
    energies_N0_right,
    energies_N1M_right,
    num_lvls,
    energy_submanifold_arr,
    Omega_pulse,
    energy_offset,
    pi/2,
    Float64[tweezer_distance,0,0],
    Float64[1,0,0],
    idx_to_position
);
flush(stdout)
# Same Hamiltonian with the 11th argument set to Omega_pulse and the 13th to 0.
@time H_dd_x_arr = compute_pulse_phase_twobody_hamiltonian(
    states_N0_left,
    states_N1M_left,
    energies_N0_left,
    energies_N1M_left,
    states_N0_right,
    states_N1M_right,
    energies_N0_right,
    energies_N1M_right,
    num_lvls,
    energy_submanifold_arr,
    Omega_pulse,
    energy_offset,
    0.,
    Float64[tweezer_distance,0,0],
    Float64[1,0,0],
    idx_to_position
);

# Pulse propagators, one matrix exponential per submanifold:
#   U_pi_2 = exp(+i · H_dd_y · π / (2·Omega_pulse)),   U_pi = exp(−i · H_dd_x · π / Omega_pulse).
# NOTE(review): the two exponents carry opposite signs — confirm this is the intended
# rotation sense for the π/2 pulse.
U_pi_2_arr = exp.(im*H_dd_y_arr * pi/(2Omega_pulse));
U_pi_arr = exp.(-im*H_dd_x_arr * pi/(Omega_pulse));
# Note: the Hilbert space is kron(spin1, spin2, motion1, motion2)
# Single-spin dephasing operator sqrt(gamma_deph) · diag(1, −1) / 2.
L_deph_sp = sqrt(gamma_deph) * Diagonal([1, -1])/2
# Dephasing operators on the full num_lvls^2 motional space; these two are not referenced
# again in the visible part of this script.
L_deph_1 = kron(L_deph_sp, I(2), I(num_lvls^2))
L_deph_2 = kron(I(2), L_deph_sp, I(num_lvls^2))
# Per-submanifold dephasing operators [spin 1, spin 2] as sparse matrices, and their adjoints.
L_deph_arr_arr = [sparse.([kron(L_deph_sp, I(2), I(length(energy_submanifold))), kron(I(2), L_deph_sp, I(length(energy_submanifold)))]) for energy_submanifold in energy_submanifold_arr]
L_deph_dag_arr_arr = [sparse.([kron(L_deph_sp, I(2), I(length(energy_submanifold)))', kron(I(2), L_deph_sp, I(length(energy_submanifold)))']) for energy_submanifold in energy_submanifold_arr]

# Thermal initial state: a Boltzmann-weighted mixture of the states U_pi_2 |dd, n_left, n_right>,
# accumulated into one density matrix per energy submanifold.
# The level spacing of the left N0 spectrum sets the unit in which `beta` is expressed.
hb_om = energies_N0_left[2] - energies_N0_left[1];
beta_w_units = beta/hb_om;
# Partition functions over the retained levels of each tweezer.
Z_left = sum(exp.(-beta_w_units * energies_N0_left))
Z_right = sum(exp.(-beta_w_units * energies_N0_right))
rho0_arr = map(U_pi -> zeros(ComplexF64, size(U_pi)), U_pi_arr);
# The left index runs fastest, matching Iterators.product(1:num_lvls, 1:num_lvls).
@time for ho_level_right in 1:num_lvls, ho_level_left in 1:num_lvls
    # Skip pairs whose combined excitation exceeds the cut-off.
    ho_level_left + ho_level_right - 2 > max_ho_level && continue

    manifold, idx = find_energy_subspace_idx(energy_submanifold_arr, (ho_level_left, ho_level_right))
    basis_label = "dd_$(ho_level_left)_$(ho_level_right)"

    # Unit vector on the labelled basis state, then rotated by the π/2 pulse.
    ket = zeros(ComplexF64, size(H_dd_arr[manifold], 1));
    ket[state_to_idx_arr[manifold][basis_label]] = 1;
    ket = U_pi_2_arr[manifold] * ket;

    pair_energy = energies_N0_left[ho_level_left] + energies_N0_right[ho_level_right]
    weight = 1 / (Z_left * Z_right) * exp(-beta_w_units * (pair_energy))
    rho0_arr[manifold] .+= weight * ket .* ket'
end

# Not needed if we already work in the rotating frame
# Time series of the observables; each starts with the value stored for index 0 of
# `evaluation_tsteps`, and the time loop appends one entry per step.
bell_fidelity_evol = Float64[0.5];
bell_fidelity2_evol = Float64[0.5];
p00_evol = Float64[0];
p11_evol = Float64[1];
p01_evol = Float64[0];
p10_evol = Float64[0];
# Only dephasing is used as a Lindblad channel.
lindblad_arr_arr = L_deph_arr_arr
lindblad_dag_arr_arr = L_deph_dag_arr_arr
# Per submanifold: 2π·H_dd − (i/2)·Σ L'L and the counterpart with +(i/2)·Σ L'L.
H_dd_nh_arr = [sparse(2pi * H_dd - im/2 * mapreduce(L -> L'L, +, lindblad_arr)) for (H_dd, lindblad_arr) in Iterators.zip(H_dd_arr, lindblad_arr_arr)];
H_dd_nh_dag_arr = [sparse(2pi * H_dd + im/2 * mapreduce(L -> L'L, +, lindblad_arr)) for (H_dd, lindblad_arr) in Iterators.zip(H_dd_arr, lindblad_arr_arr)];
# Uninitialised matrices with the shape of each rho0.
# NOTE(review): presumably scratch buffers for mastereq_dt — confirm in master-equation-diffeqs.jl.
dummy1_arr = [similar(rho0) for rho0 in rho0_arr];
dummy2_arr = [similar(rho0) for rho0 in rho0_arr];
# One tuple per submanifold, in the order expected by evolve_rho! / mastereq_dt.
params_arr = collect(Iterators.zip(H_dd_nh_arr, H_dd_nh_dag_arr, lindblad_arr_arr, lindblad_dag_arr_arr, dummy1_arr, dummy2_arr));


# Working copies: `rhot_prev_arr` carries the state accumulated over the first half of the
# sequence from one step to the next; `rhot_back_arr` receives the second-half evolution.
rhot_arr = deepcopy(rho0_arr)
rhot_prev_arr = deepcopy(rho0_arr)
rhot_back_arr = deepcopy(rho0_arr)
# For every total time tt = tstep, 2·tstep, …: evolve tt/2, apply U_pi, evolve tt/2 again,
# apply U_pi_2, then record populations and fidelities.
@time for (it,tt) in enumerate(tstep:tstep:(tmax+tstep/2))
    println("Time step $it/$(length(evaluation_tsteps)-1)")
    GC.gc(); GC.gc();

    tic = time()
    # First half: advance the previous first-half state by tstep/2 (cumulative time tt/2).
    info_arr = evolve_rho!(rhot_arr, rhot_prev_arr, tstep/2, params_arr, ode_tol, ode_maxiter)
    @assert all([info.converged .>= 1 for info in info_arr])
    # Element-wise rebinding: rhot_prev_arr[i] now refers to the matrix held in rhot_arr[i].
    # The next line stores freshly allocated products in rhot_arr, so the saved state is kept.
    rhot_prev_arr .= rhot_arr
    rhot_arr .= U_pi_arr .* rhot_prev_arr .* adjoint.(U_pi_arr)
    # Second half: evolve the full tt/2 starting from the state after U_pi.
    # NOTE(review): the tolerance here is a literal 1e-5 while the first half uses `ode_tol` — confirm this is intended.
    info_arr = evolve_rho!(rhot_back_arr, rhot_arr, tt/2, params_arr, 1e-5, ode_maxiter)
    @assert all([info.converged .>= 1 for info in info_arr])

    rhot_arr .= U_pi_2_arr .* rhot_back_arr .* adjoint.(U_pi_2_arr)
    rhot_diag_arr = real(diag.(rhot_arr))

    println("Time taken = $(time() - tic)");
    flush(stdout)

    # Populations: the diagonal of each submanifold splits into four consecutive blocks of
    # length dim_h_motion, summed as p11, p01, p10, p00 (in that order) over all submanifolds.
    p11_t = sum([sum(rhot_diag[1:dim_h_motion]) for (rhot_diag, dim_h_motion) in Iterators.zip(rhot_diag_arr, length.(energy_submanifold_arr))])
    p01_t = sum([sum(rhot_diag[dim_h_motion+1:2dim_h_motion]) for (rhot_diag, dim_h_motion) in Iterators.zip(rhot_diag_arr, length.(energy_submanifold_arr))])
    p10_t = sum([sum(rhot_diag[2dim_h_motion+1:3dim_h_motion]) for (rhot_diag, dim_h_motion) in Iterators.zip(rhot_diag_arr, length.(energy_submanifold_arr))])
    p00_t = sum([sum(rhot_diag[3dim_h_motion+1:4dim_h_motion]) for (rhot_diag, dim_h_motion) in Iterators.zip(rhot_diag_arr, length.(energy_submanifold_arr))])
    push!(p00_evol, p00_t)
    push!(p11_evol, p11_t)
    push!(p01_evol, p01_t)
    push!(p10_evol, p10_t)

    # Overlaps of the motion-reduced state with (|1> ± i|4>)/sqrt(2) in the 4-dim spin basis.
    bell_fidelity = state_fidelity(rhot_arr, energy_submanifold_arr, [1, 0, 0, im]/sqrt(2))
    bell_fidelity2 = state_fidelity(rhot_arr, energy_submanifold_arr, [1, 0, 0, -im]/sqrt(2))
    # The overlaps must be real up to a relative error of 1e-8 before the real part is stored.
    @assert isapprox(bell_fidelity, real(bell_fidelity), rtol=1e-8)
    @assert isapprox(bell_fidelity2, real(bell_fidelity2), rtol=1e-8)
    push!(bell_fidelity_evol, real(bell_fidelity))
    push!(bell_fidelity2_evol, real(bell_fidelity2))
end


# Collect the time axis and all recorded observables into one table; every column has one
# entry per element of `evaluation_tsteps`.
df = DataFrame(
    t=tstep*evaluation_tsteps,
    P00=p00_evol,
    P11=p11_evol,
    P01=p01_evol,
    P10=p10_evol,
    BellPlusI=bell_fidelity_evol,
    BellMinusI=bell_fidelity2_evol,
);

# Create the output directory if needed, then write the table as "<save_dir>/<pstring>.csv".
mkpath(save_dir)

CSV.write("$save_dir/$pstring.csv", df)
